{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ME-NET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.utils.data as Data\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from defense import MENet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "torch.cuda.set_device(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                         shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 85.240%\n"
     ]
    }
   ],
   "source": [
    "# model trained with matrix-estimation + adversarial training\n",
    "checkpoint = torch.load('models/ckpt.t7_ResNet18_advtrain_concat_usvt_0.5_white')\n",
    "\n",
    "# model trained with matrix-estimation + standard training\n",
    "#checkpoint = torch.load('models/ckpt.t7_ResNet18_pure_concat_usvt_0.5_white')\n",
    "\n",
    "model = checkpoint['model']\n",
    "model.device_ids = [0]\n",
    "model.output_device = 0\n",
    "rng_state = checkpoint['rng_state']\n",
    "torch.set_rng_state(rng_state)\n",
    "\n",
    "model = model.to(device)\n",
    "menet_model = MENet(model)\n",
    "\n",
    "model.eval()\n",
    "menet_model.eval()\n",
    "\n",
    "# check model accuracy\n",
    "def test_generalization(model, loader):\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (inputs, targets) in enumerate(loader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            with torch.no_grad():\n",
    "                outputs = model(inputs)\n",
    "\n",
    "            _, pred_idx = torch.max(outputs.data, 1)\n",
    "            total += targets.size(0)\n",
    "            correct += pred_idx.eq(targets.data).cpu().sum().float()\n",
    "\n",
    "    return 100. * correct / total\n",
    "print('Acc: %.3f%%' % test_generalization(menet_model, testloader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attack 1: BPDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 0.555986 0.145\n",
      "0 20 0.89678466 0.24\n",
      "0 40 1.2432835 0.29\n",
      "0 60 1.5040343 0.335\n",
      "0 80 2.0296295 0.39\n",
      "0 100 2.333188 0.42\n",
      "0 120 2.643069 0.45\n",
      "0 140 2.9956958 0.485\n",
      "0 160 3.1431735 0.485\n",
      "0 180 3.416809 0.515\n",
      "1 0 0.6715832 0.125\n",
      "1 20 1.0303277 0.23\n",
      "1 40 1.3949406 0.285\n",
      "1 60 1.8998712 0.33\n",
      "1 80 2.1450403 0.39\n",
      "1 100 2.454036 0.435\n",
      "1 120 2.68185 0.44\n",
      "1 140 2.965865 0.455\n",
      "1 160 3.1352801 0.495\n",
      "1 180 3.3462453 0.485\n",
      "2 0 0.45458162 0.135\n",
      "2 20 0.7306506 0.185\n",
      "2 40 1.0998504 0.25\n",
      "2 60 1.4138063 0.31\n",
      "2 80 1.7645721 0.37\n",
      "2 100 2.152749 0.405\n",
      "2 120 2.4440153 0.45\n",
      "2 140 2.8014889 0.475\n",
      "2 160 2.8414094 0.475\n",
      "2 180 3.1407578 0.485\n",
      "3 0 0.7717588 0.16\n",
      "3 20 1.0925505 0.255\n",
      "3 40 1.4218915 0.28\n",
      "3 60 1.8660396 0.355\n",
      "3 80 2.238803 0.39\n",
      "3 100 2.7191668 0.43\n",
      "3 120 2.8068614 0.48\n",
      "3 140 3.1176968 0.475\n",
      "3 160 3.5125623 0.515\n",
      "3 180 3.413438 0.515\n",
      "4 0 0.46111125 0.125\n",
      "4 20 0.8502369 0.21\n",
      "4 40 1.0776811 0.265\n",
      "4 60 1.4680864 0.31\n",
      "4 80 1.8179624 0.365\n",
      "4 100 2.2098296 0.38\n",
      "4 120 2.3671854 0.395\n",
      "4 140 2.5689921 0.425\n",
      "4 160 2.9055276 0.45\n",
      "4 180 3.016227 0.48\n"
     ]
    }
   ],
   "source": [
    "epsilon = 8.0/255\n",
    "num_steps = 200\n",
    "step_size = (5 * epsilon) / num_steps\n",
    "\n",
    "attack_model = menet_model\n",
    "attack_model.eval()\n",
    "\n",
    "batch_size = 200\n",
    "\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "testloader = iter(testloader)\n",
    "\n",
    "num_batches = 1000 // batch_size\n",
    "\n",
    "x_adv = []\n",
    "\n",
    "for i in range(num_batches):\n",
    "    (x_test, y_test) = next(testloader)\n",
    "    x_test = x_test.to(device)\n",
    "    y_test = y_test.to(device)\n",
    "\n",
    "    x = x_test.detach()\n",
    "    x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)\n",
    "    x = torch.clamp(x, 0, 1)\n",
    "\n",
    "    for j in range(num_steps):\n",
    "        x.requires_grad_()\n",
    "        with torch.enable_grad():\n",
    "            logits = attack_model(x)\n",
    "            _, pred_idx = torch.max(logits.data, 1)\n",
    "            loss = F.cross_entropy(logits, y_test, reduce=False)\n",
    "        grad = torch.autograd.grad(torch.mean(loss), [x])[0]\n",
    "\n",
    "        success = ~(pred_idx.eq(y_test.data).cpu().numpy().astype(np.bool))\n",
    "\n",
    "        if j % 20 == 0:\n",
    "            print(i, j, torch.mean(loss).detach().cpu().numpy(), np.mean(success))\n",
    "\n",
    "        x = x.detach() + step_size * torch.sign(grad.detach())\n",
    "        x = torch.min(torch.max(x, x_test - epsilon), x_test + epsilon)\n",
    "        x = torch.clamp(x, 0, 1)\n",
    "    \n",
    "    x_adv.append(x.detach().cpu().numpy())\n",
    "\n",
    "x_adv = np.concatenate(x_adv, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc: 49.7%\n"
     ]
    }
   ],
   "source": [
    "# evaluate attack\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=0)\n",
    "testloader = iter(testloader)\n",
    "\n",
    "correct = []\n",
    "with torch.no_grad():\n",
    "    for i in range(num_batches):\n",
    "        (_, y_test) = next(testloader)\n",
    "        y_test = y_test.to(device)\n",
    "        x = x_adv[i*batch_size:(i+1)*batch_size]\n",
    "        outputs_adv = menet_model(torch.from_numpy(x).to(device))\n",
    "        loss = F.cross_entropy(outputs_adv, y_test, reduce=False)\n",
    "        _, pred_idx = torch.max(outputs_adv.data, 1)\n",
    "        correct.append(pred_idx.eq(y_test.data).cpu().float())\n",
    "\n",
    "print(\"acc: {:.1f}%\".format(100 * np.mean(np.concatenate(correct))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attack 2: BPDA + EOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 0.6040787 0.60407865 0.115 0.115\n",
      "0 20 2.131285 2.156087 0.42 0.47\n",
      "0 40 4.044399 4.094599 0.555 0.625\n",
      "0 60 5.54005 5.631854 0.645 0.735\n",
      "0 80 6.3823624 6.5319157 0.715 0.81\n",
      "0 100 6.6823926 6.9006596 0.725 0.845\n",
      "0 120 6.7796726 7.075388 0.73 0.855\n",
      "0 140 6.8647156 7.1857057 0.73 0.865\n",
      "0 160 6.9455533 7.2675524 0.725 0.87\n",
      "0 180 6.968773 7.3123436 0.735 0.875\n",
      "1 0 0.7038498 0.7038498 0.15 0.15\n",
      "1 20 2.1949828 2.2188222 0.385 0.46\n",
      "1 40 3.9998465 4.055197 0.55 0.63\n",
      "1 60 5.512513 5.5912766 0.66 0.75\n",
      "1 80 6.340132 6.477013 0.73 0.805\n",
      "1 100 6.630049 6.848197 0.73 0.835\n",
      "1 120 6.7330837 7.0133104 0.75 0.85\n",
      "1 140 6.8331246 7.1317973 0.74 0.855\n",
      "1 160 6.892745 7.213182 0.765 0.86\n",
      "1 180 6.9034123 7.264055 0.76 0.86\n",
      "2 0 0.4001552 0.40015516 0.11 0.11\n",
      "2 20 1.883779 1.9113045 0.345 0.44\n",
      "2 40 3.9070714 3.9436333 0.545 0.62\n",
      "2 60 5.4797034 5.567206 0.635 0.74\n",
      "2 80 6.32311 6.469389 0.71 0.79\n",
      "2 100 6.6210656 6.827215 0.72 0.825\n",
      "2 120 6.8080587 7.0396833 0.71 0.86\n",
      "2 140 6.899159 7.171466 0.73 0.87\n",
      "2 160 6.9489408 7.255375 0.735 0.88\n",
      "2 180 6.951769 7.305923 0.74 0.89\n",
      "3 0 0.73021567 0.73021567 0.155 0.155\n",
      "3 20 2.3652809 2.3808732 0.425 0.455\n",
      "3 40 4.326644 4.3924384 0.565 0.65\n",
      "3 60 5.9239182 6.014492 0.645 0.745\n",
      "3 80 6.787075 6.937279 0.735 0.795\n",
      "3 100 7.072802 7.3109035 0.735 0.83\n",
      "3 120 7.2324247 7.506997 0.745 0.845\n",
      "3 140 7.3321652 7.6243815 0.745 0.855\n",
      "3 160 7.3685503 7.706116 0.75 0.86\n",
      "3 180 7.400965 7.75485 0.745 0.86\n",
      "4 0 0.48596132 0.48596132 0.15 0.15\n",
      "4 20 1.9052229 1.9258837 0.36 0.41\n",
      "4 40 3.7918108 3.84049 0.53 0.615\n",
      "4 60 5.340828 5.409531 0.64 0.72\n",
      "4 80 6.188445 6.320548 0.69 0.75\n",
      "4 100 6.468932 6.6547117 0.725 0.78\n",
      "4 120 6.5909534 6.8436766 0.695 0.8\n",
      "4 140 6.69826 6.9681983 0.71 0.82\n",
      "4 160 6.72647 7.048544 0.695 0.825\n",
      "4 180 6.766426 7.1018615 0.695 0.83\n"
     ]
    }
   ],
   "source": [
    "epsilon = 8.0/255\n",
    "num_steps = 200\n",
    "step_size = (5 * epsilon) / num_steps\n",
    "\n",
    "attack_model = menet_model\n",
    "attack_model.eval()\n",
    "\n",
    "batch_size = 200\n",
    "\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=0)\n",
    "testloader = iter(testloader)\n",
    "\n",
    "num_batches = 1000 // batch_size\n",
    "\n",
    "x_adv = []\n",
    "\n",
    "for i in range(num_batches):\n",
    "    (x_test, y_test) = next(testloader)\n",
    "    x_test = x_test.to(device)\n",
    "    y_test = y_test.to(device)\n",
    "    \n",
    "    x_best = x_test.cpu().numpy().copy()\n",
    "    best_loss = -100*np.ones(len(x_best), dtype=np.float32)\n",
    "    all_success = np.zeros(len(x_best), dtype=np.bool)\n",
    "\n",
    "    x = x_test.detach()\n",
    "    x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)\n",
    "    x = torch.clamp(x, 0, 1)\n",
    "\n",
    "    for j in range(num_steps):\n",
    "        x.requires_grad_()\n",
    "        with torch.enable_grad():\n",
    "            logits = attack_model(x)\n",
    "            _, pred_idx = torch.max(logits.data, 1)\n",
    "            loss = F.cross_entropy(logits, y_test, reduce=False)\n",
    "        grad = torch.autograd.grad(torch.mean(loss), [x])[0]\n",
    "\n",
    "        # take the average gradient\n",
    "        num_r = 40\n",
    "        for r in range(num_r):\n",
    "            x.requires_grad_()\n",
    "            with torch.enable_grad():\n",
    "                logits = attack_model(x)\n",
    "                loss_r = F.cross_entropy(logits, y_test, reduce=False)\n",
    "            grad_r = torch.autograd.grad(torch.mean(loss_r), [x])[0]\n",
    "            grad += grad_r\n",
    "            loss += loss_r\n",
    "\n",
    "        loss /= (1.0 + num_r)\n",
    "\n",
    "        # retain the example with the highest average loss\n",
    "        success = ~(pred_idx.eq(y_test.data).cpu().numpy().astype(np.bool))\n",
    "        better = loss.detach().cpu().numpy() > best_loss\n",
    "        if np.any(better):\n",
    "            x_best[better] = x.detach().cpu().numpy()[better]\n",
    "            best_loss[better] = loss.detach().cpu().numpy()[better]\n",
    "\n",
    "        all_success |= success\n",
    "\n",
    "        if j % 20 == 0:\n",
    "            print(i, j, torch.mean(loss).detach().cpu().numpy(), np.mean(best_loss), np.mean(success), np.mean(all_success))\n",
    "\n",
    "        x = x.detach() + step_size * torch.sign(grad.detach())\n",
    "        x = torch.min(torch.max(x, x_test - epsilon), x_test + epsilon)\n",
    "        x = torch.clamp(x, 0, 1)\n",
    "    \n",
    "    x_adv.append(x_best)\n",
    "x_adv = np.concatenate(x_adv, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc: 26.0%\n"
     ]
    }
   ],
   "source": [
    "# evaluate attack\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=0)\n",
    "testloader = iter(testloader)\n",
    "\n",
    "correct = []\n",
    "with torch.no_grad():\n",
    "    for i in range(num_batches):\n",
    "        (_, y_test) = next(testloader)\n",
    "        y_test = y_test.to(device)\n",
    "        x = x_adv[i*batch_size:(i+1)*batch_size]\n",
    "        outputs_adv = menet_model(torch.from_numpy(x).to(device))\n",
    "        loss = F.cross_entropy(outputs_adv, y_test, reduce=False)\n",
    "        _, pred_idx = torch.max(outputs_adv.data, 1)\n",
    "        correct.append(pred_idx.eq(y_test.data).cpu().float())\n",
    "\n",
    "print(\"acc: {:.1f}%\".format(100 * np.mean(np.concatenate(correct))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ME",
   "language": "python",
   "name": "me"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
